{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "O9I9rz0pTamX"
   },
   "source": [
    "# MNLI Diagnostic Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rXbD_U1_VDnw"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tC9teoazUnW8"
   },
   "source": [
    "#### Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8aU3Z9szuMU9"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "# Always start from /content so re-running this cell does not clone a nested copy\n",
    "# of the repo inside the previous checkout.\n",
    "%cd /content\n",
    "# Clone only when the checkout is missing; a repeated `git clone` into an existing\n",
    "# directory fails, and %%capture would hide that failure.\n",
    "![ -d jiant ] || git clone https://github.com/nyu-mll/jiant.git\n",
    "%cd jiant\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -r requirements-no-torch.txt\n",
    "%pip install --no-deps -e ./"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KGJcCmRzU1Qb"
   },
   "source": [
    "#### Download data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jKCz8VksvFlN"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "%cd /content\n",
    "# Download/preprocess MNLI and Diagnostic data\n",
    "!PYTHONPATH=/content/jiant python jiant/jiant/scripts/download_data/runscript.py download \\\n",
    "    --tasks mnli mnli_mismatched glue_diagnostics --output_path=/content/tasks/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rQKSAhYzVIlv"
   },
   "source": [
    "## `jiant` Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "v88oXqmBvFuK"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Put the cloned repo first on the import path. The membership check keeps the\n",
    "# cell idempotent: re-running it does not pile up duplicate sys.path entries.\n",
    "if \"/content/jiant\" not in sys.path:\n",
    "    sys.path.insert(0, \"/content/jiant\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ibmMT7CXv1_P"
   },
   "outputs": [],
   "source": [
    "import jiant.proj.main.tokenize_and_cache as tokenize_and_cache\n",
    "import jiant.proj.main.export_model as export_model\n",
    "import jiant.proj.main.scripts.configurator as configurator\n",
    "import jiant.proj.main.runscript as main_runscript\n",
    "import jiant.shared.caching as caching\n",
    "import jiant.utils.python.io as py_io\n",
    "import jiant.utils.display as display\n",
    "import os\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "HPZHyLOlVp07"
   },
   "source": [
    "#### Download model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "K06qUGjkKWa7"
   },
   "outputs": [],
   "source": [
    "# Export the pretrained `roberta-base` model under ./models/roberta-base;\n",
    "# the training cell reads model/model.p and model/config.json from that folder.\n",
    "export_model.export_model(\n",
    "    output_base_path=\"./models/roberta-base\",\n",
    "    hf_pretrained_model_name_or_path=\"roberta-base\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dV-T-8r1V0wf"
   },
   "source": [
    "#### Tokenize and cache\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "22bNWQajO4zm"
   },
   "outputs": [],
   "source": [
    "# Tokenize and cache each task, restricted to the phases this notebook uses:\n",
    "# train/val for MNLI, val for MNLI-mismatched, test for the GLUE diagnostic set.\n",
    "task_phases = {\n",
    "    \"mnli\": [\"train\", \"val\"],\n",
    "    \"mnli_mismatched\": [\"val\"],\n",
    "    \"glue_diagnostics\": [\"test\"],\n",
    "}\n",
    "\n",
    "# One loop instead of three copy-pasted calls: the task name is the only thing\n",
    "# that varies in the config path and the cache directory.\n",
    "for task_name, phases in task_phases.items():\n",
    "    tokenize_and_cache.main(tokenize_and_cache.RunConfiguration(\n",
    "        task_config_path=f\"./tasks/configs/{task_name}_config.json\",\n",
    "        hf_pretrained_model_name_or_path=\"roberta-base\",\n",
    "        output_dir=f\"./cache/{task_name}\",\n",
    "        phases=phases,\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iLk_X0KypUyr"
   },
   "outputs": [],
   "source": [
    "# Peek at the first cached MNLI training example: token ids, then tokens.\n",
    "mnli_train_cache = caching.ChunkedFilesDataCache(\"./cache/mnli/train\")\n",
    "row = mnli_train_cache.load_chunk(0)[0][\"data_row\"]\n",
    "print(row.input_ids, row.tokens, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2n00e6Xrp1bI"
   },
   "outputs": [],
   "source": [
    "# Peek at the first cached MNLI-mismatched validation example: token ids, then tokens.\n",
    "mnli_mismatched_val_cache = caching.ChunkedFilesDataCache(\"./cache/mnli_mismatched/val\")\n",
    "row = mnli_mismatched_val_cache.load_chunk(0)[0][\"data_row\"]\n",
    "print(row.input_ids, row.tokens, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cjwnG_xXCelU"
   },
   "outputs": [],
   "source": [
    "# Peek at the first cached GLUE diagnostics test example: token ids, then tokens.\n",
    "glue_diagnostics_test_cache = caching.ChunkedFilesDataCache(\"./cache/glue_diagnostics/test\")\n",
    "row = glue_diagnostics_test_cache.load_chunk(0)[0][\"data_row\"]\n",
    "print(row.input_ids, row.tokens, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3MBuH19IWOr0"
   },
   "source": [
    "#### Writing a run config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "pQYtl7xTKsiP"
   },
   "outputs": [],
   "source": [
    "# Describe the run: train on MNLI, validate on both MNLI dev sets, and use the\n",
    "# GLUE diagnostic set as the test task.\n",
    "task_configurator = configurator.SimpleAPIMultiTaskConfigurator(\n",
    "    # Where the task configs and tokenized caches from the earlier cells live.\n",
    "    task_config_base_path=\"./tasks/configs\",\n",
    "    task_cache_base_path=\"./cache\",\n",
    "    # Which task feeds which phase.\n",
    "    train_task_name_list=[\"mnli\"],\n",
    "    val_task_name_list=[\"mnli\", \"mnli_mismatched\"],\n",
    "    test_task_name_list=[\"glue_diagnostics\"],\n",
    "    # Batch sizes and schedule. NOTE(review): epochs=0.1 presumably means a tenth\n",
    "    # of an epoch, to keep the demo short - confirm against the configurator.\n",
    "    train_batch_size=8,\n",
    "    eval_batch_size=16,\n",
    "    epochs=0.1,\n",
    "    num_gpus=1,\n",
    ")\n",
    "jiant_run_config = task_configurator.create_config()\n",
    "display.show_json(jiant_run_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "usn2oQo8ILwi"
   },
   "source": [
    "Configure all three tasks to use an `mnli` head."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "624iqo12Gs7D"
   },
   "outputs": [],
   "source": [
    "# Point every task at the same `mnli` task model.\n",
    "jiant_run_config[\"taskmodels_config\"][\"task_to_taskmodel_map\"] = dict.fromkeys(\n",
    "    [\"mnli\", \"mnli_mismatched\", \"glue_diagnostics\"], \"mnli\"\n",
    ")\n",
    "\n",
    "# Write the edited config to disk at the path the training cell loads it from.\n",
    "os.makedirs(\"./run_configs/\", exist_ok=True)\n",
    "py_io.write_json(jiant_run_config, \"./run_configs/jiant_run_config.json\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BBKkvXzdYPqZ"
   },
   "source": [
    "#### Start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JdwWPgjQWx6I"
   },
   "outputs": [],
   "source": [
    "# Assemble the run arguments, then hand them to jiant's run loop.\n",
    "run_args = main_runscript.RunConfiguration(\n",
    "    # Inputs: the run config and the exported model files written by earlier cells.\n",
    "    jiant_task_container_config_path=\"./run_configs/jiant_run_config.json\",\n",
    "    hf_pretrained_model_name_or_path=\"roberta-base\",\n",
    "    model_path=\"./models/roberta-base/model/model.p\",\n",
    "    model_config_path=\"./models/roberta-base/model/config.json\",\n",
    "    # Output location.\n",
    "    output_dir=\"./runs/run1\",\n",
    "    force_overwrite=True,\n",
    "    # Optimisation and evaluation schedule.\n",
    "    learning_rate=1e-5,\n",
    "    eval_every_steps=500,\n",
    "    # Phases to run and artifacts to keep.\n",
    "    do_train=True,\n",
    "    do_val=True,\n",
    "    do_save=True,\n",
    "    write_test_preds=True,\n",
    ")\n",
    "main_runscript.run_loop(run_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jOxAmtQmHu9E"
   },
   "outputs": [],
   "source": [
    "# Load the test predictions from the run directory and show the GLUE diagnostics entry.\n",
    "# NOTE(review): torch.load unpickles the file; only load files produced by this notebook.\n",
    "test_preds_path = \"./runs/run1/test_preds.p\"\n",
    "test_preds = torch.load(test_preds_path)\n",
    "test_preds[\"glue_diagnostics\"]"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "jiant MNLI Diagnostic Example",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}